import random
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import numpy as np
import json
import matplotlib.pyplot as plt
import argparse
import os
import torch

def fix_seed(seed):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG,
    PyTorch's CPU generator and all PyTorch CUDA generators, then puts
    cuDNN into deterministic mode.

    Args:
        seed: Integer seed shared by every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # Ask cuDNN for deterministic kernels so GPU runs are repeatable.
    torch.backends.cudnn.deterministic = True

def _str2bool(value):
    """Convert a command-line string such as ``"true"``/``"False"``/``"1"`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """Parse the command-line options for demo construction.

    Args:
        argv: Optional list of argument strings. When None (the default),
            ``sys.argv[1:]`` is parsed, exactly as before.

    Returns:
        argparse.Namespace with the options defined below.
    """
    parser = argparse.ArgumentParser(description="Zero-shot-CoT for ScienceQA Dataset with Train Filtering")
    parser.add_argument(
        "--task", type=str, default="scienceqa",
        help="dataset used for experiment"
    )
    parser.add_argument(
        "--max_ra_len", type=int, default=5,
        help="maximum number of reasoning chains"
    )
    parser.add_argument(
        "--problems_file", type=str, default="/home/test/yxl/MCoT/data/aokvqa/aokvqa_v1p0_train.json",
        help="path to the problems json file"
    )
    parser.add_argument(
        "--demo_save_dir", type=str, default="/home/test/yxl/MCoT/aokvqa/tool",
        help="where to save the constructed demonstrations"
    )
    parser.add_argument("--random_seed", type=int, default=192, help="random seed")
    parser.add_argument(
        "--encoder", type=str, default="all-MiniLM-L6-v2",
        help="which sentence-transformer encoder for clustering"
    )
    parser.add_argument(
        "--sampling", type=str, default="center",
        help="whether to sample the cluster center first"
    )
    # type=bool would turn "--debug False" into True; parse the string explicitly.
    parser.add_argument(
        "--debug", type=_str2bool, default=True, help="debug mode (true/false)"
    )
    parser.add_argument(
        "--num_clusters", type=int, default=4,
        help="number of clusters for KMeans"
    )
    return parser.parse_args(argv)


def _build_corpus(problems):
    """Turn A-OKVQA problem records into clustering text and demo candidates.

    Args:
        problems: Dict mapping question_id -> problem record. Each record is
            read for 'question', 'choices', 'correct_choice_idx' and,
            optionally, 'rationales'.

    Returns:
        ``(corpus, records)``: two equal-length lists. ``corpus[k]`` is the
        text that gets embedded. ``records[k]`` is the candidate demo dict
        (keys qid, question, choices, rationale, gold_ans) for the same
        problem.
    """
    corpus = []
    records = []
    for qid, problem in problems.items():
        question_text = problem['question']
        choice_list = problem['choices']
        full_question = f"Q: {question_text}\nChoices: {', '.join(choice_list)}\n"

        # Use the first rationale as the reasoning chain for the demo.
        if problem.get('rationales'):
            rationale = problem['rationales'][0]
        else:
            print(f"Warning: QID {qid} has no rationales. Using empty string.")
            rationale = ""

        records.append({
            "qid": qid,
            "question": full_question,
            "choices": choice_list,
            "rationale": rationale,
            "gold_ans": problem['correct_choice_idx'],
        })
        # Separate question and choices with a space so they don't fuse together.
        corpus.append(f"{question_text} {' '.join(choice_list)} {rationale}")
    return corpus, records


def _normalize_rationale(text):
    """Collapse all newlines and runs of whitespace in ``text`` to single spaces."""
    text = text.replace("\n\n", "\n").replace("\n", " ").strip()
    return " ".join(text.split())


def _count_reasoning_steps(text):
    """Return the number of non-empty lines (reasoning steps) in a raw rationale."""
    return sum(1 for line in text.splitlines() if line.strip())


def _select_demos(cluster_assignment, dist, records, num_clusters, max_ra_len,
                  sampling, debug):
    """Pick at most one demonstration per KMeans cluster.

    Candidates are visited closest-to-centroid first. When ``sampling`` is
    anything other than "center", that order is shuffled with Python's
    ``random`` instead. The first candidate that passes all filters is kept:
    a non-empty rationale, at most ``max_ra_len`` reasoning lines, and an
    ending in '.', '!' or '?'.

    Args:
        cluster_assignment: Cluster label for each corpus entry.
        dist: Array from ``KMeans.transform``; ``dist[k][c]`` is the distance
            of entry k to centroid c.
        records: Candidate demo dicts aligned with ``cluster_assignment``.
        num_clusters: Number of clusters.
        max_ra_len: Maximum number of reasoning lines allowed in a rationale.
        sampling: "center" for nearest-first order, anything else to shuffle.
        debug: When True, print the full question/rationale/answer of each pick.

    Returns:
        List of demo dicts (record keys plus "cluster_id").
    """
    members = [[] for _ in range(num_clusters)]
    for idx, cluster_id in enumerate(cluster_assignment):
        members[cluster_id].append(idx)

    demos = []
    for cluster_id, member_idx in enumerate(members):
        print(f"Processing Cluster {cluster_id + 1}/{num_clusters}")
        # Stable sort: nearest to the centroid first, corpus order on ties.
        ordered = sorted(member_idx, key=lambda k: dist[k][cluster_id])

        if sampling != "center":
            random.shuffle(ordered)

        for idx in ordered:
            record = records[idx]
            raw_rationale = record["rationale"]

            if not raw_rationale:
                print(f"Skipping demo from Cluster {cluster_id + 1} with QID: {record['qid']}. Rationale is empty.")
                continue

            # Count steps on the RAW text: normalization below joins every
            # line, after which a line count would always be 1.
            if _count_reasoning_steps(raw_rationale) > max_ra_len:
                continue

            rationale = _normalize_rationale(raw_rationale)
            if not rationale or rationale[-1] not in ".!?":
                continue

            demo = {**record, "rationale": rationale, "cluster_id": cluster_id}
            demos.append(demo)
            print(f"Added demo from Cluster {cluster_id + 1} with QID: {record['qid']}")
            if debug:
                print(f"Question: {demo['question']}")
                print(f"Rationale: {rationale}")
                print(f"Answer: {demo['gold_ans']}")
                print("---")
            break  # one demo per cluster
        else:
            print(f"Warning: no candidate in Cluster {cluster_id + 1} passed the filters; no demo added.")
    return demos


def _plot_clusters(corpus_embeddings, labels, cluster_centers, num_clusters,
                   random_seed, output_file):
    """Project embeddings and centroids to 2-D with PCA and save a scatter plot.

    Args:
        corpus_embeddings: Embedding matrix that was clustered.
        labels: Cluster label per embedding (e.g. ``KMeans.labels_``).
        cluster_centers: Centroids in embedding space.
        num_clusters: Number of clusters (used to colour the centroids).
        random_seed: Seed passed to PCA.
        output_file: Path of the image to write.
    """
    pca_model = PCA(n_components=2, random_state=random_seed)
    transformed = pca_model.fit_transform(corpus_embeddings)
    centers = pca_model.transform(cluster_centers)

    fig = plt.figure(figsize=(10, 8))
    try:
        scatter = plt.scatter(x=transformed[:, 0], y=transformed[:, 1],
                              c=labels, s=50, cmap=plt.cm.Paired, alpha=0.4)
        plt.scatter(centers[:, 0], centers[:, 1],
                    s=250, marker='*', label='centroids',
                    edgecolor='black',
                    c=np.arange(0, num_clusters), cmap=plt.cm.Paired)
        plt.xticks([])
        plt.yticks([])
        plt.title(f"KMeans Clustering for A-OKVQA train Split (n={num_clusters})")
        plt.colorbar(scatter, label='Cluster Label')
        plt.savefig(output_file, dpi=600, bbox_inches='tight')
    finally:
        # Release the figure so repeated runs don't accumulate open figures.
        plt.close(fig)


def main():
    """Cluster the training problems and save one demo per cluster plus a plot.

    Steps:
      1. Load the problems JSON (a list of records keyed by 'question_id').
      2. Embed question + choices + first rationale with a SentenceTransformer.
      3. Run KMeans and pick one demo per cluster.
      4. Write ``demos_train.json`` and ``clustering_train.png`` into
         ``--demo_save_dir``.

    Raises:
        ValueError: if there are fewer problems than requested clusters.
    """
    args = parse_arguments()
    fix_seed(args.random_seed)
    encoder = SentenceTransformer(args.encoder)

    problems_file = args.problems_file
    save_dir = args.demo_save_dir
    num_clusters = args.num_clusters

    # Make sure the output directory exists.
    os.makedirs(save_dir, exist_ok=True)

    print(f"Reading problems from {problems_file}")
    with open(problems_file, "r", encoding="utf-8") as fp:
        problems = json.load(fp)
    # Index by question_id; a later duplicate id overwrites an earlier one.
    problems = {item['question_id']: item for item in problems}

    corpus, records = _build_corpus(problems)
    print(f"Total samples after filtering: {len(corpus)}")
    if len(corpus) < num_clusters:
        raise ValueError(
            f"Need at least {num_clusters} samples for {num_clusters} clusters, got {len(corpus)}."
        )

    print("Encoding corpus with SentenceTransformer...")
    corpus_embeddings = encoder.encode(corpus)

    print(f"Performing KMeans clustering with {num_clusters} clusters...")
    clustering_model = KMeans(n_clusters=num_clusters, random_state=args.random_seed)
    clustering_model.fit(corpus_embeddings)
    cluster_assignment = clustering_model.labels_
    dist = clustering_model.transform(corpus_embeddings)

    print("Building demonstration examples...")
    demos = _select_demos(cluster_assignment, dist, records, num_clusters,
                          args.max_ra_len, args.sampling, args.debug)

    demo_output_file = os.path.join(save_dir, "demos_train.json")
    print(f"Saving {len(demos)} demonstrations to {demo_output_file}")
    with open(demo_output_file, 'w', encoding="utf-8") as write_f:
        json.dump({"demos": demos}, write_f, indent=4, ensure_ascii=False)

    print("Visualizing clustering results...")
    vis_output_file = os.path.join(save_dir, "clustering_train.png")
    # Reuse the fitted labels instead of refitting KMeans just for the plot.
    _plot_clusters(corpus_embeddings, cluster_assignment,
                   clustering_model.cluster_centers_, num_clusters,
                   args.random_seed, vis_output_file)
    print(f"Clustering visualization saved to {vis_output_file}")


if __name__ == "__main__":
    # Run the demo-construction pipeline only when executed as a script, not on import.
    main()